Який найпростіший спосіб перетворити тензор фігури (batch_size, висота, ширина), заповнений n значеннями, в тензор фігури (batch_size, n, height, width)? Я створив рішення нижче, але, схоже, є простіший і швидший спосіб зробити це def batch_tensor_to_onehot (tnsr, класи): tnsr = tnsr.unsqueeze (1) res = [] для cls в діапазоні (класи): res.append ((tnsr == cls) .long ()) повернути факел. cat (res, dim = 1)
2021-02-20 08:20:11
Ви можете використовувати torch.nn.functional.one_hot. Для Вашого випадку: a = torch.nn.functional.one_hot (tnsr, num_classes = класи) out = a.permute (0, 3, 1, 2) | Ви також можете використовувати Tensor.scatter_, який уникає .permute, але, мабуть, важче зрозуміти, ніж простий метод, запропонований @Alpha. def batch_tensor_to_onehot (tnsr, класи): result = torch.zeros (tnsr.shape [0], класи, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device) result.scatter_ (1, tnsr.unsqueeze (1), 1) повернути результат Результати порівняльного аналізу Мені було цікаво і я вирішив порівняти ці три підходи. Я виявив, що, схоже, немає суттєвої відносної різниці між запропонованими методами щодо розміру, ширини або висоти партії. Відмінним фактором в першу чергу була кількість класів. Звичайно, як і в будь-якому еталоні, пробіг може відрізнятися. Орієнтовні показники були зібрані з використанням випадкових індексів та з використанням розміру партії, висоти, ширини = 100. Кожен експеримент повторювався 20 разів із повідомленням про середнє значення. Експеримент num_classes = 100 запускається один раз перед профілюванням для розминки. Результати центрального процесора показують, що оригінальний метод був, мабуть, найкращим для num_classes менше приблизно 30, тоді як для графічного процесора метод scatter_ здається найшвидшим. Тести, проведені на Ubuntu 18.04, NVIDIA 2060 Super, i7-9700K Код, що використовується для бенчмаркінгу, поданий нижче: імпортний факел з tqdm імпортувати tqdm час імпорту імпортувати matplotlib.pyplot як plt def batch_tensor_to_onehot_slavka (tnsr, класи): tnsr = tnsr.unsqueeze (1) res = [] для cls в діапазоні (класи): res.append ((tnsr == cls) .long ()) повернути факел. cat (res, dim = 1) def batch_tensor_to_onehot_alpha (tnsr, класи): result = torch.nn.functional.one_hot (tnsr, num_classes = класи) повернути результат. перестановка (0, 3, 1, 2) def batch_tensor_to_onehot_jodag (tnsr, класи): result = torch.zeros (tnsr.shape [0], класи, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device) result.scatter_ (1, tnsr.unsqueeze (1), 1) повернути результат def main (): num_classes = [2, 10, 25, 50, 100] висота = 100 ширина = 100 bs = [100] * 20 для d у ['cpu', 'cuda']: times_slavka = [] times_alpha = [] times_jodag = [] розминка = Правда для c у tqdm ([num_classes [-1]] + num_classes, ncols = 0): tslavka = 0 тальфа = 0 tjodag = 0 для b в bs: tnsr = torch.randint (c, (b, висота, ширина)). до (device = d) t0 = time.time () y = batch_tensor_to_onehot_slavka (tnsr, c) torch.cuda.synchronize () tslavka + = time.time () - t0 якщо не розминка: times_slavka.append (tslavka / len (bs)) для b в bs: tnsr = torch.randint (c, (b, height, width)). to (device = d) t0 = time.time () y = batch_tensor_to_onehot_alpha (tnsr, c) torch.cuda.synchronize () тальфа + = час.час () - t0 якщо не розминка: times_alpha.append (тальфа / лен (bs)) для b в bs: tnsr = torch.randint (c, (b, height, width)). to (device = d) t0 = time.time () y = batch_tensor_to_onehot_jodag (tnsr, c) torch.cuda.synchronize () tjodag + = time.time () - t0 якщо не розминка: times_jodag.append (tjodag / len (bs)) розминка = Невірно fig = plt.figure () ax = fig.subplots () ax.plot (num_classes, times_slavka, label = 'Slavka-cat') ax.plot (num_classes, times_alpha, label = 'Alpha-one_hot') ax.plot (num_classes, times_jodag, label = 'jodag-scatter_') ax.set_xlabel ('num_classes') ax.set_ylabel ('час (і)') ax.set_title (f '{d} орієнтир') ax.legend () plt.savefig (f '{d} .png') plt.show () якщо __name__ == "__основна__": основний () | Ваша відповідь StackExchange.ifUsing ("редактор", function () { StackExchange.using ("externalEditor", function () { StackExchange.using ("фрагменти", function () { StackExchange.snippets.init (); }); }); }, "фрагменти коду"); StackExchange.ready (function () { var channelOptions = { теги: "" .split (""), id: "1" }; initTagRenderer ("". split (""), "" .split (""), channelOptions); StackExchange.using ("externalEditor", function () { // Потрібно запускати редактор після фрагментів, якщо фрагменти увімкнено якщо (StackExchange.settings.snippets.snippetsEnabled) { StackExchange.using ("фрагменти", function () { createEditor (); }); } ще { createEditor (); } }); функція createEditor () { StackExchange.prepareEditor ({ useStacksEditor: false, heartbeatType: 'відповідь', autoActivateHeartbeat: false, convertImagesToLinks: true, noModals: правда, showLowRepImageUploadWarning: true, репутаціяToPostImages: 10, bindNavPrevention: true, постфікс: "", imageUploader: { brandingHtml: "Працює на \ u003ca href = \" https: //imgur.com/ \ "\ u003e \ u003csvg class = \" svg-icon \ "width = \" 50 \ "height = \" 18 \ "viewBox = \ "0 0 50 18 \" fill = \ "none \" xmlns = \ "http: //www.w3.org/2000/svg \" \ u003e \ u003cpath d = \ "M46.1709 9.17788C46.1709 8.26454 46,2665 7,94324 47,1084 7.58816C47.4091 7,46349 47,7169 7,36433 48,0099 7.26993C48.9099 6,97997 49,672 6,73443 49,672 5.93063C49.672 5,22043 48,9832 4,61182 48,1414 4.61182C47.4335 4,61182 46,7256 4,91628 46,0943 5.50789C45.7307 4,9328 45,2525 4,66231 44,6595 4.66231C43.6264 4,66231 43,1481 5,28821 43.1481 6.59048V11.9512C43.1481 13.2535 43.6264 13.8962 44.6595 13.8962C45.6924 13.8962 46.1709 13.253546.1709 11.9512V9.17788Z \ "/ \ u003e \ u003cpath d = \" M32.492 10.1419C32.492 12.6954 34.1182 14.0484 37.0451 14.0484C39.9723 14.0484 41.5985 12.6954 41.5985 10.1419V6.59049C66.5324 4.6213239 4.9214393239421329394213293942132939421639324393243932439621639639621639 38,5948 5,28821 38,5948 6,59049V9,60062C38,5948 10,8521 38,2696 11,5455 37,0451 11,5455C35,8209 11,5455 35,4954 10,8521 35,4954 9,60062V6,59049C35,4954 5,28821 35,0173 4,66232 34,0034 4,6232 fill-rule = \ "evenodd \" clip-rule = \ "evenodd \" d = \ "M25.6622 17.6335C27.8049 17.6335 29.3739 16.9402 30.2537 15.6379C30.8468 14.7755 30.9615 13.5579 30.9615 11.9512V6.59049C30.968 5.28821 3031631 29.4502 4.66231C28.9913 4.66231 28.4555 4.94978 28.1109 5.50789C27.499 4.86533 26.7335 4.56087 25.7005 4.56087C23.1369 4.56087 21.0134 6.57349 21.0134 9.27932C21.0134 11.9852 23.003 13.913 25.3754 13.913 12.913 13.913 13.913 13.913 C28. 1256 12.8854 28,1301 12,9342 28,1301 12.983C28.1301 14,4373 27,2502 15,2321 25,777 15.2321C24.8349 15,2321 24,1352 14,9821 23,5661 14.7787C23.176 14,6393 22,8472 14,5218 22,5437 14.5218C21.7977 14,5218 21,2429 15,0123 21,2429 15.6887C21.2429 16,7375 22,9072 17,6335 25,6622 17.6335ZM24.1317 9,27932 C24.1317 7.94324 24.9928 7.09766 26.1024 7.09766C27.2119 7.09766 28.0918 7.94324 28.0918 9.27932C28.0918 10.6321 27.2311 11.5116 26.1024 11.5116C24.9737 11.5116 24.1317 10.6491 24.1317 9.27900Zc ". 8045 13.2535 17,2637 13,8962 18,2965 13.8962C19.3298 13,8962 19,8079 13,2535 19,8079 11.9512V8.12928C19.8079 5,82936 18,4879 4,62866 16,4027 4.62866C15.1594 4,62866 14,279 4,98375 13,3609 5.88013C12.653 5,05154 11,6581 4,62866 10,3573 4.62866C9.34336 4,62866 8,57809 4,89931 7,9466 5.5079C7. 58314 4.9328 7.10506 4.66232 6.51203 4.66232C5.47873 4.66232 5.00066 5.28821 5.00066 6.59049V11.9512C5.00066 13.2535 5.47873 13.8962 6.51203 13.8962C7.54479 13.8962 8.0232 13 .2535 8.0232 11.9512V8.90741C8.0232 7.58817 8.44431 6.91179 9.53458 6.91179C10.5104 6.91179 10.893 7.58817 10.893 8.94108V11.9512C10.893 13.2535 11.3711 13.8962 12.4044 13.8962C13.4375 13.89.913 1379.913.9713 C16.4027 6.91179 16.8045 7.58817 16.8045 8.94108V11.9512Z \ "/ \ u003e \ u003cpath d = \" M3.31675 6.59049C3.31675 5.28821 2.83866 4.66232 1.82471 4.66232C0.791758 4.66232 0.313354 5.28821.313354 1,82471 13,8962C2,85798 13,8962 3,31675 13,2535 3,31675 11,9512V6,59049Z \ "/ \ u003e \ u003cpath d = \" M1,87209 0,400291C0,843612 0,429291 0 1,1159 0 1,98861C0 2,887869 0,8228769 3,77676 3,77676 3,67676 3,77676 3,67676 3,67676 3,77676 3,77676 3,67676 3,67676 3,67676 3,67676 3,67676 3,77676 3,776 C3.7234 1.1159 2.90056 0.400291 1.87209 0.400291Z \ "fill = \" # 1BB76E \ "/ \ u003e \ u003c / svg \ u003e \ u003c / a \ u003e", contentPolicyHtml: "Внески користувачів, ліцензовані під \ u003ca href = \" https: //stackoverflow.com/help/licensing \ "\ u003ecc by-sa \ u003c / a \ u003e \ u003ca href = \" https://stackoverflow.com / legal / content-policy \ "\ u003e (політика щодо вмісту) \ u003c / a \ u003e", allowUrls: true }, onDemand: правда, discardSelector: ".discard-answer" , odmahShowMarkdownHelp: true, enableTables: true, enableSnippets: true }); } }); Дякуємо за надання відповіді на Stack Overflow! Будь ласка, не забудьте відповісти на питання. Надайте деталі та поділіться своїми дослідженнями! Але уникайте ... Прохання про допомогу, роз’яснення або відповідь на інші відповіді. Складання заяв на основі думки; підкріпіть їх посиланнями або особистим досвідом. Щоб дізнатись більше, перегляньте наші поради щодо написання чудових відповідей. Чернетку збережено Чернетку відкинуто Зареєструйтесь або увійдіть StackExchange.ready (function () { StackExchange.helpers.onClickDraftSave ('# login-link'); }); Зареєструйтесь за допомогою Google Зареєструйтесь за допомогою Facebook Зареєструйтесь за допомогою електронної пошти та пароля Подати Опублікувати в якості гостя Ім'я Електронна пошта Обов’язково, але ніколи не показується StackExchange.ready ( function () { StackExchange.openid.initPostLogin ('. New-post-login', 'https% 3a% 2f% 2fstackoverflow.com% 2fquestions% 2f62245173% 2fpytorch-transform-tensor-to-one-hot% 23new-answer', 'question_page' ); } ); Опублікувати в якості гостя Ім'я Електронна пошта Обов’язково, але ніколи не показується Опублікуйте свою відповідь Викинути Натискаючи «Опублікувати свою відповідь», ви погоджуєтесь з нашими умовами надання послуг, політикою конфіденційності та політикою файлів cookie Не відповідь, яку ви шукаєте? Перегляньте інші запитання, позначені тегом python pytorch tensor one-hot-encoding, або задайте власне запитання.